import pandas as pd
import numpy as np
import pickle
import time
import scipy
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn as nn
import torch

import torchvision

import argparse
import os
import random
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.nn.functional as F

import torchvision.models as models

import seaborn as sns

from sklearn.neighbors import KNeighborsClassifier

from tqdm import tqdm

from copy import deepcopy



# Experiment configuration: module-level globals read by load_cnn() and run().
sample_size: int = 10  # How many test images to consider per technique and randomization
runs: int = 5  # How many experiment runs to aggregate; run i is saved to test<i>.csv
DEVICE: str = 'cpu'  # Device load_cnn() moves the CNN to; test batches are moved here too
DATAROOT: str = 'data'  # Directory holding the precomputed .npy arrays loaded by run()


def main():
    """Script entry point; delegates all work to ``run()``."""
    return run()


def load_dataloaders(b_size=32, shuffle=False):
    """Return ``(train_loader, test_loader)`` DataLoaders over FashionMNIST.

    Both splits use the same transform: ``ToTensor`` followed by
    ``Normalize`` with mean 0.5 and std 0.5 (single channel), i.e.
    ``(pixel - 0.5) / 0.5``. The data is downloaded to
    ``./data/FashionMNIST`` if it is not already there.

    Args:
        b_size: batch size for both loaders.
        shuffle: whether the training loader shuffles. The test loader never
            shuffles, so test batches always follow dataset order.

    Returns:
        ``(train_loader, test_loader)``, each using 2 worker processes.
    """
    # Normalize(mean, std[, inplace]): pass only mean and std. A third
    # positional tuple would land in ``inplace`` and switch on in-place mode.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    def make_loader(train, shuffle_batches):
        # Build one FashionMNIST split and wrap it in a DataLoader.
        dataset = torchvision.datasets.FashionMNIST(
            root='./data/FashionMNIST',
            train=train,
            download=True,
            transform=transform)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=b_size,
            shuffle=shuffle_batches,
            num_workers=2)

    train_loader = make_loader(train=True, shuffle_batches=shuffle)
    test_loader = make_loader(train=False, shuffle_batches=False)
    return train_loader, test_loader


class CNN(nn.Module):
    """Five conv/BN/ReLU blocks, a 7x7 average pool and a 10-way linear head.

    ``forward(I)`` returns ``(logits, x, C)``:

    * ``C`` -- output of the fifth conv block (128 channels, after ReLU);
    * ``x`` -- ``C`` average-pooled and reshaped to ``(batch, 128)``;
    * ``logits`` -- ``linear(x)``, shape ``(batch, 10)``.

    The reshape needs the pooled map to be 1x1. That holds for 1x28x28
    inputs, because the two stride-2 convs reduce 28 -> 14 -> 7.

    Submodule names and their definition order are load-bearing. They are
    the ``state_dict`` keys of the saved checkpoint, and they fix the order
    in which ``named_parameters()`` yields parameters.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(num_features=8)

        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=2, padding=2)
        self.bn2 = nn.BatchNorm2d(num_features=16)

        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.bn3 = nn.BatchNorm2d(num_features=32)

        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=2, padding=2)
        self.bn4 = nn.BatchNorm2d(num_features=64)

        self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(num_features=128)

        self.avgpool = nn.AvgPool2d(kernel_size=7)
        self.linear = nn.Linear(in_features=128, out_features=10)
        self.relu = nn.ReLU()

    def forward(self, I):
        """Return ``(logits, pooled features, last conv feature map)`` for batch ``I``."""
        blocks = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        features = I
        for conv, bn in blocks:
            features = self.relu(bn(conv(features)))

        pooled = self.avgpool(features)
        flat = pooled.view(pooled.shape[0], pooled.shape[1])
        return self.linear(flat), flat, features


### Pre-Requisites for Algorithm
def load_cnn():
    """Build a ``CNN``, load ``weights/cnn.pth`` into it and return it in eval mode on ``DEVICE``.

    The checkpoint is first materialised on the CPU (``map_location``), then
    the model is moved to ``DEVICE``. Every call returns a fresh, independent
    model instance.

    NOTE(review): ``torch.load`` unpickles the file; only load trusted checkpoints.
    """
    model = CNN()
    checkpoint = torch.load('weights/cnn.pth', map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)
    model.eval()
    return model.to(DEVICE)


class netClassifier(nn.Module):
    """Classifier head of a trained CNN, applied directly to a conv feature map.

    Re-uses ``netC.avgpool`` and ``netC.linear``; they are shared, not copied.
    So ``forward(C)`` maps the last conv block's output to class logits.
    """

    def __init__(self, netC):
        super(netClassifier, self).__init__()
        self.net = netC

    def forward(self, C):
        """Average-pool ``C``, flatten it to ``(batch, in_features)`` and apply the linear head."""
        x = self.net.avgpool(C)
        # Take the width from the linear layer instead of hard-coding 128,
        # so the head works for any channel count the wrapped net uses.
        x = x.view(-1, self.net.linear.in_features)
        logits = self.net.linear(x)
        return logits



def evaluate_expt1(la, lb, df, tech, rand_layer, nns=None):
    """
    Collect data comparing nearest-neighbour sets.

    For every query ``i`` and every cut-off ``k`` in ``nns``, record the
    fraction of the first ``k`` neighbours in ``la[i]`` that also appear among
    the first ``k`` neighbours in ``lb[i]``. The denominator is the length of
    ``la[i][:k]``, so it is shorter than ``k`` when ``la[i]`` is. An empty
    ``la[i]`` yields NaN instead of raising ``ZeroDivisionError``.

    Args:
        la: per-query neighbour index sequences from the original network.
        lb: per-query neighbour index sequences from the randomised network.
            Indexed in step with ``la``, so it needs at least ``len(la)`` entries.
        df: frame to append to, with columns
            ``['Overlap', 'Technique', 'NNs', 'LayerRand']``.
        tech: value written to the ``Technique`` column.
        rand_layer: value written to the ``LayerRand`` column.
        nns: neighbour cut-offs to evaluate. Defaults to
            ``[1, 5, 10, 50, 100, 200, 500, 1000]``.

    Returns:
        ``df`` unchanged if ``la`` is empty. Otherwise a new DataFrame with
        ``df``'s rows followed by one row per (query, cut-off), carrying a
        0..n-1 index.
    """
    if nns is None:
        nns = [1, 5, 10, 50, 100, 200, 500, 1000]
    columns = ['Overlap', 'Technique', 'NNs', 'LayerRand']

    # Collect plain records and build the frame once. Calling pd.concat per
    # row makes the whole loop quadratic in the number of rows.
    records = []
    for i, full_a in enumerate(la):
        full_b = lb[i]
        for k in nns:  # `k`, not `nn`: avoid shadowing the torch.nn import
            a = full_a[:k]
            b = full_b[:k]
            overlap = len(set(a) & set(b)) / len(a) if len(a) else float('nan')
            records.append({'Overlap': overlap, 'Technique': tech,
                            'NNs': k, 'LayerRand': rand_layer})

    if not records:
        return df
    new_rows = pd.DataFrame(records, columns=columns)
    if df.empty:
        # Concatenating onto an empty frame is deprecated in recent pandas
        # (FutureWarning) and can degrade the column dtypes to object.
        return new_rows
    return pd.concat([df, new_rows], ignore_index=True)


def grad_cos(net, I, label):
    """
    Implement Grad-Cos: gradient of the classification loss w.r.t. ``x``.

    ``x`` is the second output of ``net(I)``. For ``CNN`` it is the pooled
    feature vector that feeds the final linear layer.

    Args:
        net: module whose forward returns ``(logits, x, C)``.
        I: input batch.
        label: target class indices for ``I``.

    Returns:
        Tensor shaped like ``x`` holding d(mean cross-entropy)/dx. It is not
        attached to an autograd graph.
    """
    # Clears any accumulated parameter .grad; autograd.grad itself does not
    # write into .grad.
    net.zero_grad()
    logits, x, _ = net(I)
    loss = F.cross_entropy(logits, label)
    # First-order gradient only. Callers just read the values (they
    # .detach() them), so create_graph/retain_graph would only add a
    # second-order graph and keep memory alive on every call.
    (grad,) = torch.autograd.grad(loss, x)
    return grad


def _unit_normalize(v):
    """Return ``v`` scaled to unit L2 norm, or ``v`` unchanged if its norm is 0.

    For non-zero vectors, Euclidean distance between unit vectors ranks
    neighbours the same way as cosine similarity.
    """
    norm = np.linalg.norm(v)
    # Guard on the norm, not the sum: a vector with mixed-sign entries (e.g. a
    # gradient) can sum to 0 without being all zeros.
    return v if norm == 0 else v / norm


def _normalize_rows(features):
    """Flatten each row of ``features``, unit-normalise it, and stack the rows into a 2-D array."""
    return np.array([_unit_normalize(row.flatten()) for row in features])


def _fit_1nn(features, labels):
    """Fit the brute-force Euclidean neighbour index used by every technique."""
    index = KNeighborsClassifier(n_neighbors=1, algorithm="brute", metric='euclidean')
    index.fit(features, labels)
    return index


def _nearest(index, query, k=1000):
    """Training-set indices of the ``k`` nearest neighbours of ``query``, as a list."""
    return index.kneighbors(X=[query], n_neighbors=k, return_distance=False)[0].tolist()


def _randomized_cnn(rand_f, num_layers):
    """Return a freshly loaded CNN with some parameters re-drawn from N(0, 1).

    Args:
        rand_f: which parameters to re-draw, by position in ``named_parameters()``:
            'linear' -> the last two (for ``CNN``: linear.weight / linear.bias);
            'Conv1' -> positions 0..1; 'ConvHalf' -> 0..11;
            'AllCNN' or anything else -> 0..30.
        num_layers: number of parameter tensors in the network.
    """
    net = load_cnn()
    up_to = {'Conv1': 1, 'ConvHalf': 11}.get(rand_f, 30)
    for idx, (name, param) in enumerate(net.named_parameters()):
        if rand_f == 'linear':
            randomize = idx >= num_layers - 2
        else:
            randomize = idx <= up_to
        if randomize:
            # Assigning .data keeps the Parameter (and its requires_grad).
            param.data = torch.randn(param.shape)
            print("Changed This Layer:", name)
    return net


def _twin_query(net, img, label):
    """Twin-system query: predicted class's linear-weight row times the pooled features."""
    logits, x, _ = net(img)
    pred = torch.argmax(logits, dim=1)[0].item()
    return (net.linear.weight[pred] * x[0]).detach().numpy()


def _gradcos_query(net, img, label):
    """Grad-Cos query: unit-normalised loss gradient w.r.t. the pooled features."""
    return _unit_normalize(grad_cos(net, img, label)[0].detach().numpy())


def _dknn_query(net, img, label):
    """DkNN query: unit-normalised pooled features of the first image in the batch."""
    _, x, _ = net(img)
    return _unit_normalize(x[0].detach().numpy())


def _exmatchina_query(net, img, label):
    """ExMatchina query: unit-normalised, flattened last-conv feature map."""
    _, _, C = net(img)
    return _unit_normalize(C.flatten().detach().numpy())


def _collect_neighbours(test_loader, netC, netC2, index, featurize):
    """Query ``index`` with exactly the first ``sample_size`` test batches under both networks.

    Returns:
        ``(org_nns, rand_nns)``: one 1000-neighbour index list per batch,
        from ``netC`` and from ``netC2`` respectively.
    """
    from itertools import islice

    org_nns, rand_nns = [], []
    for img, label in tqdm(islice(test_loader, sample_size), total=sample_size):
        img, label = img.to(DEVICE), label.to(DEVICE)
        org_nns.append(_nearest(index, featurize(netC, img, label)))
        rand_nns.append(_nearest(index, featurize(netC2, img, label)))
    return org_nns, rand_nns


def run():
    """Run the weight-randomisation experiment and plot the results.

    Steps:
    1. Build one neighbour index per technique from training-set features:
       Twin-System, Grad-Cos, DkNN and ExMatchina.
    2. For each of ``runs`` repetitions and each randomisation scheme
       ('linear', 'Conv1', 'ConvHalf', 'AllCNN'), query every index with the
       first ``sample_size`` test images. Each image is queried once with the
       trained CNN and once with a partially randomised copy, and
       ``evaluate_expt1`` records the overlap of the two neighbour lists.
    3. Write each repetition to ``test<i>.csv``.
    4. Concatenate all repetitions and save one line plot per scheme under
       ``figs/``.

    Reads ``X_train_cont.npy``, ``X_train_conv.npy``, ``X_train_x.npy`` and
    ``X_train_y.npy`` from ``DATAROOT``.
    """
    netC = load_cnn()
    train_loader, test_loader = load_dataloaders(b_size=1)
    y_train = train_loader.dataset.targets

    X_train_c = np.load(os.path.join(DATAROOT, "X_train_cont.npy"))
    X_train_C = np.load(os.path.join(DATAROOT, "X_train_conv.npy"))
    X_train_x = np.load(os.path.join(DATAROOT, "X_train_x.npy"))
    train_preds = np.load(os.path.join(DATAROOT, "X_train_y.npy"))

    num_layers = sum(1 for _ in netC.named_parameters())

    # Twin-system (COLE): raw rows of X_train_cont.npy, labelled with X_train_y.npy.
    twin = _fit_1nn(X_train_c, train_preds)
    # Grad-Cos: per-training-image loss gradients, computed with the trained CNN.
    grad_feats = np.array([_gradcos_query(netC, img, label) for img, label in tqdm(train_loader)])
    grad_twin = _fit_1nn(grad_feats, y_train)
    # DkNN: unit-normalised rows of X_train_x.npy (queried with pooled features x).
    dknn_twin = _fit_1nn(_normalize_rows(X_train_x), y_train)
    # ExMatchina: unit-normalised rows of X_train_conv.npy (queried with conv map C).
    exmatchina_twin = _fit_1nn(_normalize_rows(X_train_C[:60000]), y_train)

    techniques = [
        ('Twin-System', twin, _twin_query),
        ('Grad-Cos', grad_twin, _gradcos_query),
        ('DkNN', dknn_twin, _dknn_query),
        ('ExMatchina', exmatchina_twin, _exmatchina_query),
    ]

    for run_idx in range(runs):
        df = pd.DataFrame(columns=['Overlap', 'Technique', 'NNs', 'LayerRand'])

        for rand_f in ['linear', 'Conv1', 'ConvHalf', 'AllCNN']:
            netC2 = _randomized_cnn(rand_f, num_layers)
            for tech, index, featurize in techniques:
                org_nns, rand_nns = _collect_neighbours(test_loader, netC, netC2, index, featurize)
                df = evaluate_expt1(org_nns, rand_nns, df, tech, rand_f)

        df['Technique'] = df['Technique'].replace({'twin-system': 'Twin-System'})
        df['LayerRand'] = df['LayerRand'].replace({'linear': 'Randomize Last Linear Layer',
                                                   'Conv1': 'Randomize First Convolutional Layer',
                                                   'ConvHalf': 'Randomize Half of Convolutions',
                                                   'AllCNN': 'Randomize Entire CNN'})
        df = df.rename(columns={"NNs": "Number of NNs Considered"})

        # One CSV per repetition.
        df.to_csv('test' + str(run_idx) + '.csv')

    # Aggregate every saved run (driven by `runs`, not a fixed count) and plot.
    df = pd.concat([pd.read_csv('test' + str(i) + '.csv') for i in range(runs)])
    os.makedirs('figs', exist_ok=True)
    for layer in df.LayerRand.value_counts().keys():
        subset = df[df.LayerRand == layer]
        # NOTE(review): `ci` is deprecated in seaborn >= 0.12 in favour of `errorbar`.
        sns.lineplot(x="Number of NNs Considered", y='Overlap', hue='Technique', data=subset,
                     err_style="bars", ci=68)
        plt.title(layer)
        plt.savefig('figs/' + layer + '.pdf')
        plt.close()


# Run the experiment only when executed as a script, not when imported.
if __name__ == '__main__':
    main()



